package cn.zhangfusheng.elasticsearch.transactional;

import cn.zhangfusheng.elasticsearch.exception.GlobalSystemException;
import cn.zhangfusheng.elasticsearch.thread.ThreadLocalDetail;
import lombok.Getter;
import lombok.experimental.Accessors;
import org.elasticsearch.action.DocWriteRequest;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.RethrottleRequest;
import org.elasticsearch.tasks.TaskId;

import java.io.IOException;
import java.util.Objects;

/**
 * Transaction control for a group of Elasticsearch write operations.
 * <p>
 * Document write requests are buffered into a single {@link BulkRequest}. They are only sent once the
 * pending-participant counter ({@code waitExecute}) drops to zero in {@link #commit}. Tasks that obtained
 * {@link #getTaskId()} are rethrottled at the same point.
 * @author fusheng.zhang
 * @date 2022-03-01 20:38:59
 */
@Getter
@Accessors(chain = true)
public class TransactionalControl {

    /**
     * @param rollbackFor exception type that, when passed to {@link #commit}, causes the buffered work to be skipped
     * @param taskId      task id handed out by {@link #getTaskId()} and rethrottled on the final commit
     */
    public TransactionalControl(Class<? extends Throwable> rollbackFor, TaskId taskId) {
        this.rollbackFor = rollbackFor;
        this.taskId = taskId;
        this.bulkRequest = new BulkRequest();
    }

    /**
     * Exception type (including subclasses) that prevents the buffered work from being executed.
     */
    private final Class<? extends Throwable> rollbackFor;

    /**
     * Task id handed out by {@link #getTaskId()}.
     */
    private final TaskId taskId;

    /**
     * Buffer of document write requests, sent as one bulk call on the final commit.
     */
    private final BulkRequest bulkRequest;

    /**
     * Number of document write requests added to {@link #bulkRequest}.
     */
    private int requestNum;

    /**
     * Number of times {@link #getTaskId()} has been called; when non-zero the task is rethrottled on commit.
     */
    private int taskNum;

    /**
     * Pending-participant counter: {@link #commit} decrements it by one and runs the buffered work when it
     * reaches zero. Presumably incremented through {@link #addWaitExecute(int)} by each transactional
     * scope — confirm with the caller.
     */
    private int waitExecute;


    /**
     * Buffers document write requests for the final bulk call. {@code null} or empty input is ignored.
     * @param requests write requests to buffer
     * @return always {@code true}
     */
    public boolean addRequest(DocWriteRequest<?>... requests) {
        if (Objects.nonNull(requests) && requests.length > 0) {
            synchronized (this) {
                bulkRequest.add(requests);
                requestNum += requests.length;
            }
        }
        return true;
    }

    /**
     * Adjusts the pending-participant counter.
     * @param num delta to add (may be negative)
     * @return the counter value after the adjustment
     */
    public int addWaitExecute(int num) {
        synchronized (this) {
            this.waitExecute += num;
            return this.waitExecute;
        }
    }

    /**
     * Returns the transaction's task id and records one more task user, so that {@link #commit} knows a
     * rethrottle is required. Note: unlike a plain getter, every call increments {@code taskNum}.
     * @return the task id passed to the constructor
     */
    public TaskId getTaskId() {
        synchronized (this) {
            this.taskNum += 1;
        }
        return taskId;
    }

    /**
     * @param throwable exception raised by the transactional code, or {@code null}
     * @return {@code true} when {@code throwable} is non-null and an instance of the configured rollback type
     */
    public boolean rollbackFor(Throwable throwable) {
        // NOTE(review): throws NullPointerException if the constructor received a null rollbackFor.
        return Objects.nonNull(throwable) && rollbackFor.isInstance(throwable);
    }

    /**
     * Commits this participant's part of the transaction.
     * <p>
     * If {@code throwable} matches the rollback type, nothing is executed. Otherwise the pending-participant
     * counter is decremented. When it reaches zero, the buffered bulk request is sent and the task is
     * rethrottled. {@link ThreadLocalDetail#remove()} is then called, even if one of those requests failed.
     * @param throwable           exception raised by the transactional code, or {@code null}
     * @param restHighLevelClient client used to send the requests
     * @throws IOException           if a request to Elasticsearch fails at the transport level
     * @throws GlobalSystemException if any item of the bulk request failed
     */
    public void commit(Throwable throwable, RestHighLevelClient restHighLevelClient) throws IOException {
        if (this.rollbackFor(throwable)) {
            // NOTE(review): this path neither decrements waitExecute nor clears ThreadLocalDetail;
            // presumably the caller cleans up after a rollback — confirm.
            return;
        }
        if (this.addWaitExecute(-1) != 0) {
            // Other participants still have to commit; only the last one executes the buffered work.
            return;
        }
        final int bufferedRequests;
        final int registeredTasks;
        synchronized (this) {
            bufferedRequests = this.requestNum;
            registeredTasks = this.taskNum;
        }
        try {
            if (bufferedRequests > 0) {
                this.executeBulk(restHighLevelClient);
            }
            if (registeredTasks > 0) {
                this.rethrottleTasks(restHighLevelClient);
            }
        } finally {
            // Clear even when a request above throws; if ThreadLocalDetail is thread-local as its name
            // suggests, skipping this would leak state into the next use of a pooled thread.
            ThreadLocalDetail.remove();
        }
    }

    /**
     * Sends the buffered bulk request, applying the refresh policy from ThreadLocalDetail when present.
     */
    private void executeBulk(RestHighLevelClient client) throws IOException {
        ThreadLocalDetail.getRefreshPolicy().ifPresent(bulkRequest::setRefreshPolicy);
        BulkResponse response = client.bulk(bulkRequest, ThreadLocalDetail.requestOptions());
        if (response.hasFailures()) {
            throw new GlobalSystemException(response.buildFailureMessage());
        }
    }

    /**
     * Sends an unthrottling rethrottle request for the task id to the reindex, update-by-query and
     * delete-by-query rethrottle APIs. Reads the field directly, because {@link #getTaskId()} would
     * count another task.
     */
    private void rethrottleTasks(RestHighLevelClient client) throws IOException {
        RethrottleRequest request = new RethrottleRequest(taskId);
        client.reindexRethrottle(request, ThreadLocalDetail.requestOptions());
        client.updateByQueryRethrottle(request, ThreadLocalDetail.requestOptions());
        client.deleteByQueryRethrottle(request, ThreadLocalDetail.requestOptions());
    }
}
